Learning and inference of latents with MC_MAZE_SMALL data

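The code below fits a linear dynamical system (LDS) to the MC_MAZE_SMALL data by expectation maximization and then infers its latent states with Kalman filtering and smoothing.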

import numpy as np
import plotly.graph_objects as go

from dandi.dandiapi import DandiAPIClient
from pynwb import NWBHDF5IO

import ssm.inference
import ssm.learning
import ssm.neural_latents.utils
import ssm.neural_latents.plotting

Define variables

Here max_iter is set to a small value so that this documentation builds quickly. For a real fit, set max_iter=2000 together with a smaller convergence tolerance, tol=1e-3.

In addition, to plot all the data rather than only the interval between from_time = 100.0 and to_time = 130.0 seconds, set from_time = -np.inf and to_time = np.inf.

# ---- data source ----
# When True, stream the NWB file from the DANDI archive; otherwise read the
# local copy, which mirrors the DANDI layout under a per-dandiset directory.
get_data_from_Dandi = True
dandiset_ID = "000140"
dandi_filepath = "sub-Jenkins/sub-Jenkins_ses-small_desc-train_behavior+ecephys.nwb"
local_filepath = ("../../../../projects/lds_neuralLatents_MC_MAZE_SMALL/data/"
                  f"{dandiset_ID}/{dandi_filepath}")
bin_size = 0.02  # width of each spike-count bin, in the same time units as spike times

# ---- plotting ----
# One line style and one color per trial event column.
events_names = ["start_time", "target_on_time", "go_cue_time",
                "move_onset_time", "stop_time"]
events_linetypes = ["dot", "dash", "dashdot", "longdash", "solid"]
events_colors_spikes = ["white"] * len(events_names)   # drawn over the heatmap
events_colors_latents = ["black"] * len(events_names)  # drawn over latent traces
cb_alpha = 0.3  # opacity of the latent confidence bands
from_time, to_time = 100.0, 130.0  # plotted time window

# ---- model ----
n_latents = 10

# ---- estimation initial conditions ----
# Scales of the random draws used to initialise each parameter.
sigma_B = sigma_Z = sigma_Q = sigma_R = sigma_m0 = sigma_V0 = 0.1

# ---- estimation parameters ----
max_iter = 5  # kept small for a fast docs build; use e.g. 2000 with tol=1e-3
tol = 1e-1
# Estimate every parameter of the model.
vars_to_estimate = dict.fromkeys(("B", "Q", "Z", "R", "m0", "V0"), True)

Load data (streamed from DANDI or read from a local copy)

# Resolve where to read the NWB file from, then read it once inside a `with`
# block so the HDF5 handle is always closed (previously the remote ros3 handle
# opened in the DANDI branch was never closed).
if get_data_from_Dandi:
    with DandiAPIClient() as client:
        asset = client.get_dandiset(dandiset_ID,
                                    "draft").get_asset_by_path(dandi_filepath)
        # URL of the asset's content, used below for streaming via ros3
        s3_path = asset.get_content_url(follow_redirects=1, strip_query=True)
    nwb_path = s3_path
    io_kwargs = {"mode": "r", "driver": "ros3"}  # ros3: HDF5 read-only S3 driver
else:
    nwb_path = local_filepath
    io_kwargs = {"mode": "r"}

# units_df / trials_df are used after the file is closed, exactly as the
# original local-file branch already did. NOTE(review): `nwbfile` refers to a
# closed file after this block; only the dataframes should be used downstream.
with NWBHDF5IO(nwb_path, **io_kwargs) as io:
    nwbfile = io.read()
    units_df = nwbfile.units.to_dataframe()
    trials_df = nwbfile.intervals["trials"].to_dataframe()


# n_clusters: number of sorted units (one row per unit in units_df)
n_clusters = units_df.shape[0]

Bin spikes

# continuous spikes times: one spike-time array per unit (row) of units_df
continuous_spikes_times = list(units_df["spike_times"])

# binned_spikes is indexed (unit, time bin), as used by the heatmap below.
binned_spikes, bin_edges = ssm.neural_latents.utils.bin_spike_times(
    spike_times=continuous_spikes_times, bin_size=bin_size)
bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
# Square-root transform of the counts; the +0.5 offset matches the heatmap
# colorbar label. Commonly used to bring Poisson-like counts closer to Gaussian.
transformed_binned_spikes = np.sqrt(binned_spikes + 0.5)

# clip data to plot: keep every bin whose center lies in [from_time, to_time].
# Assuming bin_edges are increasing, these indices form one contiguous run.
in_window = np.flatnonzero((bin_centers >= from_time) & (bin_centers <= to_time))
if in_window.size == 0:
    raise ValueError(
        f"no bin centers fall within [{from_time}, {to_time}]; "
        "adjust from_time/to_time")
first_index = in_window[0]
last_index = in_window[-1]
# A slice stop is exclusive, so add 1 to keep the last in-window bin
# (the previous slice(first_index, last_index) silently dropped it).
to_plot_slice = slice(first_index, last_index + 1)
bin_centers_to_plot = bin_centers[to_plot_slice]
# Keep only trials that lie entirely inside the plotted window.
trials_df = trials_df[
    (trials_df["start_time"] >= from_time) & (trials_df["stop_time"] <= to_time)
]

Plot binned spikes

# Heatmap of transformed spike counts: one row per cluster, one column per
# time bin inside the plotted window.
heatmap = go.Heatmap(
    x=bin_centers_to_plot,
    z=transformed_binned_spikes[:, to_plot_slice],
    colorbar={"title": "<b>Sqrt(spike_count+0.5)</b>"},
)
fig = go.Figure(data=[heatmap])
# Overlay vertical lines marking each trial event in the window.
ssm.neural_latents.plotting.add_events_vlines(
    fig=fig,
    trials_df=trials_df,
    events_names=events_names,
    events_linetypes=events_linetypes,
    events_colors=events_colors_spikes,
)
fig.update_xaxes(title="Time (sec)").update_yaxes(title="Cluster Index")
fig


Parameter learning using expectation maximization

# Random initial conditions for EM. A seeded generator makes the fit (and the
# documentation build) reproducible from run to run; change init_seed to
# explore other initialisations.
init_seed = 0
rng = np.random.default_rng(init_seed)

B0 = np.diag(rng.normal(loc=0, scale=sigma_B, size=n_latents))
Z0 = rng.normal(loc=0, scale=sigma_Z, size=(n_clusters, n_latents))
# Absolute values keep the diagonal covariance initialisations non-negative.
Q0 = np.diag(np.abs(rng.normal(loc=0, scale=sigma_Q, size=n_latents)))
R0 = np.diag(np.abs(rng.normal(loc=0, scale=sigma_R, size=n_clusters)))
m0_0 = rng.normal(loc=0, scale=sigma_m0, size=n_latents)
V0_0 = np.diag(np.abs(rng.normal(loc=0, scale=sigma_V0, size=n_latents)))

# Fit the LDS parameters by expectation maximization on the transformed counts.
optim_res = ssm.learning.em_SS_LDS(
    y=transformed_binned_spikes, B0=B0, Q0=Q0, Z0=Z0, R0=R0,
    m0_0=m0_0, V0_0=V0_0, max_iter=max_iter, tol=tol,
    vars_to_estimate=vars_to_estimate,
)
LogLike[0000]=-23669214.200285
LogLike[0001]=1810885.631545
LogLike[0002]=1814935.707620
LogLike[0003]=1819273.379751
LogLike[0004]=1823741.512890

Plot log likelihood vs iteration number

# Objective value recorded by EM at each iteration.
log_likes = optim_res["log_like"]
iteration_numbers = np.arange(len(log_likes))

fig = go.Figure(data=[go.Scatter(x=iteration_numbers,
                                 y=log_likes,
                                 mode="lines+markers")])
fig.update_layout(xaxis={"title": "Iteration Number"},
                  yaxis={"title": "Lower Bound"})
fig


Kalman filtering

# Run the Kalman filter with the parameters estimated by EM.
estimated_params = {name: optim_res[name]
                    for name in ("B", "Q", "m0", "V0", "Z", "R")}
filter_res = ssm.inference.filterLDS_SS_withMissingValues_np(
    y=transformed_binned_spikes, **estimated_params)

Kalman smoothing

# Kalman smoother: combines the EM estimates with the filter outputs
# (xnn/Pnn and xnn1/Pnn1 -- presumably filtered and one-step-predicted
# means/covariances; confirm against ssm.inference).
filter_outputs = {key: filter_res[key] for key in ("xnn", "Pnn", "xnn1", "Pnn1")}
smoothing_res = ssm.inference.smoothLDS_SS(
    B=optim_res["B"],
    m0=optim_res["m0"],
    V0=optim_res["V0"],
    **filter_outputs,
)

Plot smoothed states

# Restrict the smoothed means/covariances to the plotted window (last axis is
# the time bin), then orthogonalize them using the estimated Z.
smoothed_means_window = smoothing_res["xnN"][:, :, to_plot_slice]
smoothed_covs_window = smoothing_res["PnN"][:, :, to_plot_slice]
o_means_to_plot, o_covs_to_plot = \
    ssm.neural_latents.utils.ortogonalizeMeansAndCovs(
        means=smoothed_means_window, covs=smoothed_covs_window,
        Z=optim_res["Z"])

# Plot each orthogonalized latent with its confidence band and event markers.
latents_plot_options = dict(
    bin_centers=bin_centers_to_plot,
    trials_df=trials_df,
    events_names=events_names,
    events_linetypes=events_linetypes,
    events_colors=events_colors_latents,
    cb_alpha=cb_alpha,
    legend_pattern="smoothing_{:d}",
)
fig = ssm.neural_latents.plotting.plot_latents(
    means=o_means_to_plot, covs=o_covs_to_plot, **latents_plot_options)

fig


Total running time of the script: ( 4 minutes 26.060 seconds)

Gallery generated by Sphinx-Gallery